{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FlintModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A FlintModel is required to apply adversarial attacks or to generate a robustness report. textflint lets practitioners plug in their own target model: wrap the model in a FlintModel subclass and implement the corresponding interfaces. Thanks to [TextAttack](https://github.com/QData/TextAttack) for integrating various attack methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/Users/wangxiao/code/python/RobustnessTool/TextRobustness/textrobustness')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to customize the target model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Provide the tokenizer object that matches the model. It converts each sample into model input, which may include tokenizing the text and converting tokens to ids.\n",
    "\n",
    "* Provide your own model object, which performs the actual prediction."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "e.g."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textflint.model.flint_model.torch_model import TorchModel\n",
    "from textflint.model.test_model.textcnn_torch_model import TextCNNTorchModel\n",
    "from textflint.model.test_model.glove_embedding import GloveEmbedding\n",
    "from textflint.model.tokenizers.glove_tokenizer import GloveTokenizer\n",
    "\n",
    "class TextCNNTorch(TorchModel):\n",
    "    r\"\"\"\n",
    "    Model wrapper for TextCnn implemented by pytorch.\n",
    "\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        glove_embedding = GloveEmbedding()\n",
    "        word2id = glove_embedding.word2id\n",
    "\n",
    "        super().__init__(\n",
    "            model=TextCNNTorchModel(\n",
    "                init_embedding=glove_embedding.embedding\n",
    "            ),\n",
    "            task='SA',\n",
    "            tokenizer=GloveTokenizer(\n",
    "                word_id_map=word2id,\n",
    "                unk_token_id=glove_embedding.oovid,\n",
    "                pad_token_id=glove_embedding.padid,\n",
    "                max_length=30\n",
    "            )\n",
    "        )\n",
    "        self.label2id = {\"positive\": 0, \"negative\": 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to implement the automatic evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To test a model's robustness, users can evaluate the generated samples with their own code; using FlintModel is not required. However, FlintModel provides evaluation metrics for most tasks, and its evaluation results can be used directly as input for subsequent report generation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Users have to implement two functions.\n",
    "* **unzip_samples( )**, which accepts a batch of samples and returns (**batch input features, batch labels**). The **input features** can be passed directly to **__call__( )** for prediction, while the **labels** are used to calculate metrics.\n",
    "* **__call__( )**, which accepts **batch input features** and predicts the **target labels**."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "e.g."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def unzip_samples(self, data_samples):\n",
    "        r\"\"\"\n",
    "        Unzip sample to input texts and labels.\n",
    "\n",
    "        :param list[Sample] data_samples: list of Samples\n",
    "        :return: (inputs_text), labels.\n",
    "\n",
    "        \"\"\"\n",
    "        x = []\n",
    "        y = []\n",
    "\n",
    "        for sample in data_samples:\n",
    "            x.append(sample['x'])\n",
    "            y.append(self.label2id[sample['y']])\n",
    "\n",
    "        return [x], y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def __call__(self, batch_texts):\n",
    "        r\"\"\"\n",
    "        Tokenize text, convert tokens to id and run the model.\n",
    "\n",
    "        :param batch_texts: (batch_size,) batch text input\n",
    "        :return: numpy.array()\n",
    "\n",
    "        \"\"\"\n",
    "        model_device = next(self.model.parameters()).device\n",
    "        inputs_ids = [self.encode(batch_text) for batch_text in batch_texts]\n",
    "        ids = torch.tensor(inputs_ids).to(model_device)\n",
    "\n",
    "        return self.model(ids).detach().cpu().numpy()\n",
    "\n",
    "    def encode(self, inputs):\n",
    "        r\"\"\"\n",
    "        Tokenize inputs and convert it to ids.\n",
    "\n",
    "        :param inputs: model original input\n",
    "        :return: list of inputs ids\n",
    "\n",
    "        \"\"\"\n",
    "        return self.tokenizer.encode(inputs)\n"
   ]
  },
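  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How the two functions cooperate during evaluation can be sketched without textflint. The cell below is only an illustration under stated assumptions: `ToyModel` and its keyword rule are hypothetical stand-ins for a real FlintModel subclass, samples are plain dicts with `x` and `y` keys, and the input features are assumed to be unpacked as positional arguments of **__call__( )**, as the `[x]` return value suggests."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "class ToyModel:\n",
    "    # Hypothetical stand-in for a FlintModel subclass (not part of textflint).\n",
    "    label2id = {'positive': 0, 'negative': 1}\n",
    "\n",
    "    def unzip_samples(self, data_samples):\n",
    "        # Split samples into model inputs and integer labels.\n",
    "        x = [sample['x'] for sample in data_samples]\n",
    "        y = [self.label2id[sample['y']] for sample in data_samples]\n",
    "        return [x], y\n",
    "\n",
    "    def __call__(self, batch_texts):\n",
    "        # Toy rule: column 0 scores 'positive', column 1 scores 'negative'.\n",
    "        scores = [[0.0, 1.0] if 'bad' in text else [1.0, 0.0]\n",
    "                  for text in batch_texts]\n",
    "        return np.array(scores)\n",
    "\n",
    "\n",
    "samples = [\n",
    "    {'x': 'a good movie', 'y': 'positive'},\n",
    "    {'x': 'a bad movie', 'y': 'negative'},\n",
    "]\n",
    "\n",
    "model = ToyModel()\n",
    "inputs, labels = model.unzip_samples(samples)\n",
    "# Unpack the input features into __call__ and take the highest-scoring label.\n",
    "predictions = model(*inputs).argmax(axis=1)\n",
    "accuracy = float((predictions == np.array(labels)).mean())\n",
    "print(f'accuracy: {accuracy:.2f}')"
   ]
  },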
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to implement adversarial attacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A FlintModel is required to generate adversarial attack samples. textflint currently supports attacks on only four tasks: 'SA', 'SM', 'NLI' and 'TC'."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Users have to implement two functions.\n",
    "* **unzip_samples( )**, which accepts a batch of samples and returns (**batch input features, batch labels**).\n",
    "* **get_model_grad( )**, which accepts input features and returns the gradient of the loss with respect to the input tokens."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "e.g."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def get_model_grad(self, text_inputs, loss_fn=CrossEntropyLoss()):\n",
    "        r\"\"\"\n",
    "        Get gradient of loss with respect to input tokens.\n",
    "\n",
    "        :param str|[str] text_inputs: input string or input string list\n",
    "        :param torch.nn.Module loss_fn: loss function.\n",
    "            Default is `torch.nn.CrossEntropyLoss`\n",
    "        :return: Dict of ids, tokens, and gradient as numpy array.\n",
    "\n",
    "        \"\"\"\n",
    "        if not hasattr(self.model, \"get_input_embeddings\"):\n",
    "            raise AttributeError(\n",
    "                f\"{type(self.model)} must have method `get_input_embeddings` \"\n",
    "                f\"that returns `torch.nn.Embedding` object that represents \"\n",
    "                f\"input embedding layer\"\n",
    "            )\n",
    "\n",
    "        if not isinstance(loss_fn, torch.nn.Module):\n",
    "            raise ValueError(\"Loss function must be of type `torch.nn.Module`.\")\n",
    "\n",
    "        self.model.train()\n",
    "\n",
    "        embedding_layer = self.model.get_input_embeddings()\n",
    "        original_state = embedding_layer.weight.requires_grad\n",
    "        embedding_layer.weight.requires_grad = True\n",
    "\n",
    "        emb_grads = []\n",
    "\n",
    "        def grad_hook(module, grad_in, grad_out):\n",
    "            emb_grads.append(grad_out[0])\n",
    "\n",
    "        emb_hook = embedding_layer.register_backward_hook(grad_hook)\n",
    "        self.model.zero_grad()\n",
    "        model_device = next(self.model.parameters()).device\n",
    "\n",
    "        inputs_ids = self.encode(text_inputs)\n",
    "        ids = [torch.tensor(ids).to(model_device) for ids in inputs_ids]\n",
    "\n",
    "        predictions = self.model(text_inputs)\n",
    "\n",
    "        output = predictions.argmax(dim=1)\n",
    "        loss = loss_fn(predictions, output)\n",
    "        loss.backward()\n",
    "\n",
    "        # grad w.r.t to word embeddings\n",
    "        grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy()\n",
    "\n",
    "        embedding_layer.weight.requires_grad = original_state\n",
    "        emb_hook.remove()\n",
    "        self.model.eval()\n",
    "\n",
    "        output = {\"ids\": ids[0].tolist(), \"gradient\": grad}\n",
    "\n",
    "        return output"
   ]
  },
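  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantity that **get_model_grad( )** returns, the gradient of the loss with respect to each input token's embedding, can be seen in isolation on a toy model. The cell below is a minimal sketch and not textflint's implementation: the embedding and classifier layers are hypothetical, and it reads the gradient with `retain_grad()` on the embedding output instead of a backward hook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical toy classifier: embedding -> mean pooling -> linear layer.\n",
    "embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)\n",
    "classifier = torch.nn.Linear(4, 2)\n",
    "\n",
    "ids = torch.tensor([[1, 5, 7]])            # (batch_size=1, seq_len=3)\n",
    "embedded = embedding(ids)                  # (1, 3, 4)\n",
    "embedded.retain_grad()                     # keep the gradient of this non-leaf tensor\n",
    "\n",
    "logits = classifier(embedded.mean(dim=1))  # (1, 2)\n",
    "# As in get_model_grad above, the predicted label serves as the target.\n",
    "target = logits.argmax(dim=1)\n",
    "loss = torch.nn.CrossEntropyLoss()(logits, target)\n",
    "loss.backward()\n",
    "\n",
    "# One gradient row per input token, one column per embedding dimension.\n",
    "grad = embedded.grad[0].numpy()\n",
    "print(grad.shape)"
   ]
  },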
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "textflint provides a base class for PyTorch models, TorchModel, which already implements **get_model_grad( )**. The complete FlintModel implementation below uses the PyTorch implementation of TextCNN as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from textflint.model.flint_model.torch_model import TorchModel\n",
    "from textflint.model.test_model.textcnn_torch_model import TextCNNTorchModel\n",
    "from textflint.model.test_model.glove_embedding import GloveEmbedding\n",
    "from textflint.model.tokenizers.glove_tokenizer import GloveTokenizer\n",
    "\n",
    "\n",
    "class TextCNNTorch(TorchModel):\n",
    "    r\"\"\"\n",
    "    Model wrapper for TextCnn implemented by pytorch.\n",
    "\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        glove_embedding = GloveEmbedding()\n",
    "        word2id = glove_embedding.word2id\n",
    "\n",
    "        super().__init__(\n",
    "            model=TextCNNTorchModel(\n",
    "                init_embedding=glove_embedding.embedding\n",
    "            ),\n",
    "            task='SA',\n",
    "            tokenizer=GloveTokenizer(\n",
    "                word_id_map=word2id,\n",
    "                unk_token_id=glove_embedding.oovid,\n",
    "                pad_token_id=glove_embedding.padid,\n",
    "                max_length=30\n",
    "            )\n",
    "        )\n",
    "        self.label2id = {\"positive\": 0, \"negative\": 1}\n",
    "\n",
    "    def __call__(self, batch_texts):\n",
    "        r\"\"\"\n",
    "        Tokenize text, convert tokens to id and run the model.\n",
    "\n",
    "        :param batch_texts: (batch_size,) batch text input\n",
    "        :return: numpy.array()\n",
    "\n",
    "        \"\"\"\n",
    "        model_device = next(self.model.parameters()).device\n",
    "        inputs_ids = [self.encode(batch_text) for batch_text in batch_texts]\n",
    "        ids = torch.tensor(inputs_ids).to(model_device)\n",
    "\n",
    "        return self.model(ids).detach().cpu().numpy()\n",
    "\n",
    "    def encode(self, inputs):\n",
    "        r\"\"\"\n",
    "        Tokenize inputs and convert it to ids.\n",
    "\n",
    "        :param inputs: model original input\n",
    "        :return: list of inputs ids\n",
    "\n",
    "        \"\"\"\n",
    "        return self.tokenizer.encode(inputs)\n",
    "\n",
    "    def unzip_samples(self, data_samples):\n",
    "        r\"\"\"\n",
    "        Unzip sample to input texts and labels.\n",
    "\n",
    "        :param list[Sample] data_samples: list of Samples\n",
    "        :return: (inputs_text), labels.\n",
    "\n",
    "        \"\"\"\n",
    "        x = []\n",
    "        y = []\n",
    "\n",
    "        for sample in data_samples:\n",
    "            x.append(sample['x'])\n",
    "            y.append(self.label2id[sample['y']])\n",
    "\n",
    "        return [x], y\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-torch] *",
   "language": "python",
   "name": "conda-env-.conda-torch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
